"""
Docs: https://docs.swanlab.cn/guide_cloud/integration/integration-huggingface-transformers.html
"""

import os
from typing import Optional, List, Dict, Union, Any
import logging
import swanlab

try:
    from transformers.trainer_callback import TrainerCallback
    from transformers import modelcard
except ImportError:
    raise RuntimeError(
        "This contrib module requires Transformers to be installed. "
        "Please install it with command: \n pip install transformers"
    )


def rewrite_logs(d):
    """
    Return a copy of the metrics dict `d` with keys moved into slash-separated namespaces.

    Keys starting with ``eval_`` become ``eval/<rest>``, keys starting with ``test_`` become
    ``test/<rest>``, and every other key is prefixed with ``train/``. Values are passed through
    untouched and the input dict is not modified.
    """
    # (key prefix to strip, namespace to prepend) -- checked in this order.
    namespaces = (("eval_", "eval/"), ("test_", "test/"))

    renamed = {}
    for key, value in d.items():
        for marker, namespace in namespaces:
            if key.startswith(marker):
                renamed[namespace + key[len(marker):]] = value
                break
        else:
            # No known prefix matched: treat the metric as a training metric.
            renamed["train/" + key] = value
    return renamed


class SwanLabCallback(TrainerCallback):
    """
    A `transformers` trainer callback that sends the training configuration and logged metrics to SwanLab.

    The SwanLab run is created lazily by `setup`, which is triggered by the first `on_train_begin`,
    `on_log` or `on_predict` event. All SwanLab calls are made only when `state.is_world_process_zero`
    is true.
    """

    def __init__(
        self,
        project: Optional[str] = None,
        workspace: Optional[str] = None,
        experiment_name: Optional[str] = None,
        description: Optional[str] = None,
        logdir: Optional[str] = None,
        mode: Optional[str] = None,
        **kwargs: Any,
    ):
        """
        To use the `SwanLabCallback`, pass it into the `callbacks` parameter when initializing the `transformers.Trainer`.
        This allows the Trainer to utilize SwanLab's logging and monitoring functionalities during the training process.
        Parameters are the same as for `swanlab.init`. Find more information
        [here](https://docs.swanlab.cn/api/py-init.html#swanlab-init)

        Parameters
        ----------
        project : str, optional
            The project name of the current experiment, the default is None,
            which means the current project name is the same as the current working directory.
        workspace : str, optional
            Where the current project is located, it can be an organization or a user (currently only supports yourself).
            The default is None, which means the current entity is the same as the current user.
        experiment_name : str, optional
            The experiment name you currently have open. If this parameter is not provided,
            SwanLab will generate one for you by default.
        description : str, optional
            The experiment description you currently have open,
            used for a more detailed introduction or labeling of the current experiment.
            If you do not provide this parameter, you can modify it later in the web interface.
        logdir : str, optional
            The folder will store all the log information generated during the execution of SwanLab.
            If the parameter is None,
            SwanLab will generate a folder named "swanlog" in the same path as the code execution to store the data.
            If you want to visualize the generated log files,
            simply run the command `swanlab watch` in the same path where the code is executed
            (without entering the "swanlog" folder).
            You can also specify your own folder, but you must ensure that the folder exists and preferably does not contain
            anything other than data generated by Swanlab.
            In this case, if you want to view the logs,
            you must use something like `swanlab watch -l ./your_specified_folder` to specify the folder path.
        mode : str, optional
            Allowed values are 'cloud', 'cloud-only', 'local', 'disabled'.
            If the value is 'cloud', the data will be uploaded to the cloud and the local log will be saved.
            If the value is 'cloud-only', the data will only be uploaded to the cloud and the local log will not be saved.
            If the value is 'local', the data will only be saved locally and will not be uploaded to the cloud.
            If the value is 'disabled', the data will not be saved or uploaded, just parsing the data.
        **kwargs
            Any additional keyword arguments; they are stored and forwarded unchanged to `swanlab.init`.
        """
        self._swanlab = swanlab
        self._initialized = False
        # Read once at construction time; later changes to the environment variable are not seen.
        self._log_model = os.getenv("SWANLAB_LOG_MODEL", None)

        # Arguments that `setup` forwards to `swanlab.init` when no run exists yet.
        self._swanlab_init: Dict[str, Any] = {
            "project": project,
            "workspace": workspace,
            "experiment_name": experiment_name,
            "description": description,
            "logdir": logdir,
            "mode": mode,
        }
        self._swanlab_init.update(**kwargs)

    @staticmethod
    def _warn_log_model_unsupported():
        """Emit the warning that `SWANLAB_LOG_MODEL` is set but has no effect."""
        logging.warning(
            "SwanLab does not currently support the save mode functionality. "
            "This feature will be available in a future release."
        )

    def setup(self, args, state, model, **kwargs):
        """
        Setup the optional SwanLab (*swanlab*) integration.

        One can subclass and override this method to customize the setup if needed. Find more information
        [here](https://docs.swanlab.cn/guide_cloud/integration/integration-huggingface-transformers.html).

        You can also override the following environment variables. Find more information about environment
        variables [here](https://docs.swanlab.cn/en/api/environment-variable.html#environment-variables)

        Environment:
        - **SWANLAB_API_KEY** (`str`, *optional*, defaults to `None`):
            Cloud API Key. During login, this environment variable is checked first. If it doesn't exist, the system
            checks if the user is already logged in. If not, the login process is initiated.

                - If a string is passed to the login interface, this environment variable is ignored.
                - If the user is already logged in, this environment variable takes precedence over locally stored
                login information.

        - **SWANLAB_PROJECT** (`str`, *optional*, defaults to `None`):
            Set this to a custom string to store results in a different project. If not specified, the name of the current
            running directory is used.

        - **SWANLAB_LOG_DIR** (`str`, *optional*, defaults to `swanlog`):
            This environment variable specifies the storage path for log files when running in local mode.
            By default, logs are saved in a folder named swanlog under the working directory.

        - **SWANLAB_MODE** (`Literal["local", "cloud", "disabled"]`, *optional*, defaults to `cloud`):
            SwanLab's parsing mode, which involves callbacks registered by the operator. Currently, there are three modes:
            local, cloud, and disabled. Note: Case-sensitive. Find more information
            [here](https://docs.swanlab.cn/en/api/py-init.html#swanlab-init)

        - **SWANLAB_LOG_MODEL** (`str`, *optional*, defaults to `None`):
            SwanLab does not currently support the save mode functionality. This feature will be available in a future
            release

        - **SWANLAB_WEB_HOST** (`str`, *optional*, defaults to `None`):
            Web address for the SwanLab cloud environment for private version (it's free)

        - **SWANLAB_API_HOST** (`str`, *optional*, defaults to `None`):
            API address for the SwanLab cloud environment for private version (it's free)

        Parameters
        ----------
        args :
            Training arguments; `args.to_dict()` and `args.run_name` are read.
        state :
            Trainer state; `state.is_world_process_zero` and `state.trial_name` are read.
        model :
            The model being trained, or None. Its `config`, `peft_config` and parameter counts are
            recorded in the SwanLab config when available.
        **kwargs
            Accepted for compatibility with the callback events; not used here.
        """
        # Marked before any work is done, so a failing setup is not retried on the next event.
        self._initialized = True

        if not state.is_world_process_zero:
            return

        logging.info('Automatic SwanLab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"')
        combined_dict = {**args.to_dict()}

        # On key collisions the training arguments win over the model config,
        # and both win over the synthetic "peft_config" entry.
        if hasattr(model, "config") and model.config is not None:
            model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
            combined_dict = {**model_config, **combined_dict}
        if hasattr(model, "peft_config") and model.peft_config is not None:
            combined_dict = {"peft_config": model.peft_config, **combined_dict}

        # Values derived from the Trainer / environment; used as defaults for `swanlab.init`.
        init_args = {}
        trial_name = state.trial_name
        if trial_name is not None:
            init_args["experiment_name"] = f"{args.run_name}-{trial_name}"
        elif args.run_name is not None:
            init_args["experiment_name"] = args.run_name
        init_args["project"] = os.getenv("SWANLAB_PROJECT", None)

        if self._swanlab.get_run() is None:
            # Constructor arguments take precedence over the derived defaults, but an argument
            # left at None must not erase a derived value (previously a plain `dict.update`
            # replaced the derived experiment name / project with None).
            for key, value in self._swanlab_init.items():
                if value is not None or key not in init_args:
                    init_args[key] = value
            self._swanlab.init(**init_args)

        # show transformers logo!
        self._swanlab.config["FRAMEWORK"] = "🤗transformers"
        # add config parameters (run may have been created manually)
        self._swanlab.config.update(combined_dict)

        # add number of model parameters to swanlab config
        try:
            self._swanlab.config.update({"model_num_parameters": model.num_parameters()})
            # get peft model parameters
            if type(model).__name__ in ("PeftModel", "PeftMixedModel"):
                trainable_params, all_param = model.get_nb_trainable_parameters()
                self._swanlab.config.update({"peft_model_trainable_params": trainable_params})
                self._swanlab.config.update({"peft_model_all_param": all_param})
        except AttributeError:
            # Also reached when `model` is None or does not expose these methods.
            logging.info("Could not log the number of model parameters in SwanLab due to an AttributeError.")

        # Model saving is not implemented: warn, and append a SwanLab badge to the
        # auto-generated model card comment.
        if self._log_model is not None:
            self._warn_log_model_unsupported()
            # The original markup emitted two `height` attributes (`height="280" height="32"`).
            # NOTE(review): the first is assumed to be the intended width -- confirm dimensions.
            badge_markdown = (
                f'[<img src="https://raw.githubusercontent.com/SwanHubX/assets/main/badge1.svg"'
                f' alt="Visualize in SwanLab" width="280"'
                f' height="32"/>]({self._swanlab.get_run().public.cloud.exp_url})'
            )

            # This mutates a module-level string in `transformers.modelcard`.
            modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"

    def update_config(self, config: Dict[str, Any]):
        """
        Update the SwanLab config.

        Example:
        ```python
        swanlab_callback = SwanLabCallback(...)
        swanlab_callback.update_config({"model_name": "qwen"})
        trainer = Trainer(..., callbacks=[swanlab_callback])
        ```
        """
        self._swanlab.config.update(config)

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Run `setup` at the start of training if it has not run yet."""
        if not self._initialized:
            self.setup(args, state, model, **kwargs)

    def on_train_end(self, args, state, control, model=None, processing_class=None, **kwargs):
        """Warn on the main process that model saving is unsupported when `SWANLAB_LOG_MODEL` was set."""
        if self._log_model is not None and self._initialized and state.is_world_process_zero:
            self._warn_log_model_unsupported()

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        """
        Forward the Trainer's `logs` dict to SwanLab at step `state.global_step`.

        End-of-training summary values are logged individually under ``single_value/<key>``; all other
        entries are namespaced by `rewrite_logs` and logged together with ``train/global_step``.
        Nothing is logged when `logs` is None.
        """
        single_value_scalars = (
            "train_runtime",
            "train_samples_per_second",
            "train_steps_per_second",
            "train_loss",
            "total_flos",
        )

        if not self._initialized:
            self.setup(args, state, model)
        # `logs` defaults to None; there is nothing to forward in that case.
        if logs is None or not state.is_world_process_zero:
            return

        step = state.global_step
        for k, v in logs.items():
            if k in single_value_scalars:
                self._swanlab.log({f"single_value/{k}": v}, step=step)
        non_scalar_logs = rewrite_logs({k: v for k, v in logs.items() if k not in single_value_scalars})
        self._swanlab.log({**non_scalar_logs, "train/global_step": step}, step=step)

    def on_save(self, args, state, control, **kwargs):
        """Warn on the main process that model saving is unsupported when `SWANLAB_LOG_MODEL` was set."""
        if self._log_model is not None and self._initialized and state.is_world_process_zero:
            self._warn_log_model_unsupported()

    def on_predict(self, args, state, control, metrics, **kwargs):
        """
        Log prediction `metrics` to SwanLab after namespacing their keys with `rewrite_logs`.

        Runs `setup` first if needed; the model is taken from `kwargs["model"]` when present.
        """
        if not self._initialized:
            # `setup` requires `model` positionally; supply None when the caller did not provide
            # one instead of failing with a TypeError.
            setup_kwargs = dict(kwargs)
            model = setup_kwargs.pop("model", None)
            self.setup(args, state, model, **setup_kwargs)
        if state.is_world_process_zero:
            metrics = rewrite_logs(metrics)
            self._swanlab.log(metrics)
